#!/usr/bin/env python3

def compute_intermediate_size(hidden_size: int, multiple_of: int = 64) -> int:
    """Return two thirds of ``4 * hidden_size``, rounded up to a multiple.

    The value is computed as ``floor(2 * (4 * hidden_size) / 3)`` and then
    rounded up to the nearest multiple of ``multiple_of``.

    Args:
        hidden_size: Non-negative base size.
        multiple_of: Positive granularity the result must be divisible by.

    Returns:
        The smallest multiple of ``multiple_of`` that is greater than or
        equal to ``floor(8 * hidden_size / 3)``.

    Raises:
        ValueError: If ``hidden_size`` is negative or ``multiple_of`` is not
            positive.
    """
    if hidden_size < 0:
        raise ValueError(f'hidden_size must be non-negative, got {hidden_size}')
    if multiple_of <= 0:
        raise ValueError(f'multiple_of must be positive, got {multiple_of}')

    # Pure integer arithmetic: going through a float (``int(2 * x / 3)``)
    # silently loses precision once the numerator exceeds 2**53.
    two_thirds = (2 * 4 * hidden_size) // 3
    # Ceiling division, then scale back up to get the next multiple.
    return multiple_of * ((two_thirds + multiple_of - 1) // multiple_of)


def main(argv=None) -> int:
    """Command-line entry point.

    Usage: ``<script> HIDDEN_SIZE [MULTIPLE_OF]`` (``MULTIPLE_OF`` defaults
    to 64). Prints ``intermediate_size=<value>`` to standard output.

    Args:
        argv: Argument list without the program name; defaults to
            ``sys.argv[1:]`` when ``None``.

    Returns:
        0 on success, 2 when the arguments are missing or invalid.
    """
    import sys

    arguments = sys.argv[1:] if argv is None else list(argv)
    if not arguments:
        print(f'usage: {sys.argv[0]} HIDDEN_SIZE [MULTIPLE_OF]', file=sys.stderr)
        return 2

    try:
        hidden_size = int(arguments[0])
        # Arguments beyond the second are ignored, as they always have been.
        multiple_of = int(arguments[1]) if len(arguments) > 1 else 64
        intermediate_size = compute_intermediate_size(hidden_size, multiple_of)
    except ValueError as err:
        print(f'error: {err}', file=sys.stderr)
        return 2

    print(f'intermediate_size={intermediate_size}')
    return 0


if __name__ == '__main__':
    raise SystemExit(main())
